/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Reference implementation for convolution in device-side code.
*/

#pragma once

#include "cutlass/coord.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"
#include "cutlass/conv/conv3d_problem_size.h"

namespace cutlass {
namespace reference {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace kernel {

////////////////////////////////////////////////////////////////////////////////////////////////////
///                                   Conv2d device reference kernel
////////////////////////////////////////////////////////////////////////////////////////////////////

// Conv2d Fprop reference kernel: y_out = alpha * fprop(x, w) + beta * y_in
//
// Work split, as implied by the index arithmetic below:
//   - blockIdx.x / threadIdx.x cover the flattened (n, p, q) output positions
//   - blockIdx.y / threadIdx.y cover the k output channels
//   - each thread accumulates a kThreadM x kThreadN tile of outputs
// NOTE(review): block offsets are scaled by kCtaShapeM / kCtaShapeN, so the
// launch presumably uses blockDim = (kCtaShapeM, kCtaShapeN); the launch site
// is not part of this kernel - confirm against the caller.
//
// Tensors are addressed by logical coordinate: x at (n, h, w, c), w at
// (k, r, s, c), y_in / y_out at (n, p, q, k). y_in is read only when
// beta != ElementCompute().
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator = ElementCompute,
          typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>,
          int kThreadM = 2,  // shape of a thread's tile in the GEMM M dimension
          int kThreadN = 4,  // shape of a thread's tile in the GEMM N dimension
          int kCtaShapeM = 16,  // shape of a threadblock in units of threads
          int kCtaShapeN = 8    // shape of a threadblock in units of threads
          >
__global__ void Conv2dFprop(conv::Conv2dProblemSize problem_size,
                            TensorRef<ElementA, LayoutA> tensor_x,
                            TensorRef<ElementB, LayoutB> tensor_w,
                            TensorRef<ElementC, LayoutC> tensor_y_in,
                            TensorRef<ElementC, LayoutC> tensor_y_out,
                            ElementCompute alpha, ElementCompute beta) {
    ConvertOp to_output;
    InnerProductOp mac;

    // First flattened (n, p, q) index and first k index owned by this thread
    int64_t const row_base = int64_t(blockIdx.x) * kCtaShapeM * kThreadM +
                             threadIdx.x * kThreadM;
    int const col_base =
            blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

    // Number of output positions per image
    int64_t const plane = int64_t(problem_size.P) * problem_size.Q;

    int out_n[kThreadM];
    int out_p[kThreadM];
    int out_q[kThreadM];

    ElementAccumulator frag_x[kThreadM];
    ElementAccumulator frag_w[kThreadN];
    ElementAccumulator acc[kThreadM][kThreadN];

    // Decompose each tile row into (n, p, q) and zero its accumulators
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kThreadM; ++i) {
        int64_t const row = row_base + i;
        int64_t const within_image = row % plane;

        out_n[i] = int(row / plane);
        out_p[i] = int(within_image / problem_size.Q);
        out_q[i] = int(within_image % problem_size.Q);

        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < kThreadN; ++j) {
            acc[i][j] = ElementAccumulator();
        }
    }

    // In convolution mode the filter position used to locate the activation
    // is mirrored; the filter tensor itself is still read at (r, s).
    bool const mirror_filter =
            problem_size.mode == cutlass::conv::Mode::kConvolution;

    // Reduce over filter rows, filter columns and input channels
    for (int r = 0; r < problem_size.R; ++r) {
        int const tap_r = mirror_filter ? problem_size.R - 1 - r : r;

        for (int s = 0; s < problem_size.S; ++s) {
            int const tap_s = mirror_filter ? problem_size.S - 1 - s : s;

            for (int c = 0; c < problem_size.C; ++c) {
                // Gather activations; anything outside the image reads as a
                // value-initialized element
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < kThreadM; ++i) {
                    int const h = out_p[i] * problem_size.stride_h -
                                  problem_size.pad_h +
                                  tap_r * problem_size.dilation_h;
                    int const w = out_q[i] * problem_size.stride_w -
                                  problem_size.pad_w +
                                  tap_s * problem_size.dilation_w;

                    bool const inside = out_n[i] < problem_size.N && h >= 0 &&
                                        h < problem_size.H && w >= 0 &&
                                        w < problem_size.W;

                    frag_x[i] = inside ? ElementAccumulator(tensor_x.at(
                                                 {out_n[i], h, w, c}))
                                       : ElementAccumulator();
                }

                // Gather filter elements for the output channels of this tile
                CUTLASS_PRAGMA_UNROLL
                for (int j = 0; j < kThreadN; ++j) {
                    int const k = col_base + j;

                    frag_w[j] = (k < problem_size.K)
                                        ? ElementAccumulator(
                                                  tensor_w.at({k, r, s, c}))
                                        : ElementAccumulator();
                }

                // Rank-1 update of the thread tile
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < kThreadM; ++i) {
                    CUTLASS_PRAGMA_UNROLL
                    for (int j = 0; j < kThreadN; ++j) {
                        acc[i][j] = mac(frag_x[i], frag_w[j], acc[i][j]);
                    }
                }
            }
        }
    }

    // Epilogue: scale, optionally blend with the source tensor, and store
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kThreadM; ++i) {
        bool const row_in_range = out_n[i] < problem_size.N &&
                                  out_p[i] < problem_size.P &&
                                  out_q[i] < problem_size.Q;
        if (!row_in_range) {
            continue;
        }

        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < kThreadN; ++j) {
            int const k = col_base + j;
            if (k >= problem_size.K) {
                continue;
            }

            ElementCompute source = ElementCompute();
            if (beta != ElementCompute()) {
                source = ElementCompute(
                        tensor_y_in.at({out_n[i], out_p[i], out_q[i], k}));
            }

            tensor_y_out.at({out_n[i], out_p[i], out_q[i], k}) = to_output(
                    alpha * ElementCompute(acc[i][j]) + beta * source);
        }
    }
}

// Conv3d Fprop reference kernel: y_out = alpha * fprop(x, w) + beta * y_in
//
// Work split, as implied by the index arithmetic below:
//   - blockIdx.x / threadIdx.x cover the flattened (n, z, p, q) output
//     positions
//   - blockIdx.y / threadIdx.y cover the k output channels
//   - each thread accumulates a kThreadM x kThreadN tile of outputs
// NOTE(review): block offsets are scaled by kCtaShapeM / kCtaShapeN, so the
// launch presumably uses blockDim = (kCtaShapeM, kCtaShapeN); the launch site
// is not part of this kernel - confirm against the caller.
//
// Tensors are addressed by logical coordinate: x at (n, d, h, w, c), w at
// (k, t, r, s, c), y_in / y_out at (n, z, p, q, k). y_in is read only when
// beta != ElementCompute().
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator = ElementCompute,
          typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>,
          int kThreadM = 2,  // shape of a thread's tile in the GEMM M dimension
          int kThreadN = 4,  // shape of a thread's tile in the GEMM N dimension
          int kCtaShapeM = 16,  // shape of a threadblock in units of threads
          int kCtaShapeN = 8    // shape of a threadblock in units of threads
          >
__global__ void Conv3dFprop(conv::Conv3dProblemSize problem_size,
                            TensorRef<ElementA, LayoutA> tensor_x,
                            TensorRef<ElementB, LayoutB> tensor_w,
                            TensorRef<ElementC, LayoutC> tensor_y_in,
                            TensorRef<ElementC, LayoutC> tensor_y_out,
                            ElementCompute alpha, ElementCompute beta) {
    ConvertOp convert_op;
    InnerProductOp inner_product_op;

    ElementAccumulator element_A[kThreadM];
    ElementAccumulator element_B[kThreadN];
    ElementAccumulator accum[kThreadM][kThreadN];

    // First flattened (n, z, p, q) index and first k index owned by this
    // thread
    int64_t nzpq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM +
                         threadIdx.x * kThreadM;
    int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

    int thread_n[kThreadM];
    int thread_z[kThreadM];
    int thread_p[kThreadM];
    int thread_q[kThreadM];

    // Compute N, Z, P, Q coordinates for each row of a thread's tile
    int64_t PQ = int64_t(problem_size.P) * problem_size.Q;
    int64_t ZPQ = PQ * problem_size.Z;

    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < kThreadM; ++m) {
        int64_t nzpq = nzpq_start + m;

        thread_n[m] = int(nzpq / ZPQ);

        int64_t residual = nzpq % ZPQ;
        thread_z[m] = int(residual / PQ);

        residual = residual % PQ;
        thread_p[m] = int(residual / problem_size.Q);
        thread_q[m] = int(residual % problem_size.Q);
    }

    // Clear accumulators
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < kThreadM; ++m) {
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < kThreadN; ++n) {
            accum[m][n] = ElementAccumulator();
        }
    }

    // Compute convolution: reduce over filter depth, rows, columns and input
    // channels
    for (int T = 0; T < problem_size.T; ++T) {
        for (int R = 0; R < problem_size.R; ++R) {
            for (int S = 0; S < problem_size.S; ++S) {
                for (int C = 0; C < problem_size.C; ++C) {
                    // Filter position used to locate the activation
                    int filter_t = T;
                    int filter_r = R;
                    int filter_s = S;

                    // Convolution mode mirrors every filter axis, each with
                    // its own loop index (depth uses T, not R).
                    if (problem_size.mode ==
                        cutlass::conv::Mode::kConvolution) {
                        filter_t = problem_size.T - 1 - T;
                        filter_r = problem_size.R - 1 - R;
                        filter_s = problem_size.S - 1 - S;
                    }

                    // Load from activations tensor; positions outside the
                    // input volume contribute a value-initialized element
                    CUTLASS_PRAGMA_UNROLL
                    for (int m = 0; m < kThreadM; ++m) {
                        int d = thread_z[m] * problem_size.stride_d -
                                problem_size.pad_d +
                                filter_t * problem_size.dilation_d;
                        int h = thread_p[m] * problem_size.stride_h -
                                problem_size.pad_h +
                                filter_r * problem_size.dilation_h;
                        int w = thread_q[m] * problem_size.stride_w -
                                problem_size.pad_w +
                                filter_s * problem_size.dilation_w;

                        if (thread_n[m] < problem_size.N && d >= 0 &&
                            d < problem_size.D && h >= 0 &&
                            h < problem_size.H && w >= 0 &&
                            w < problem_size.W) {
                            element_A[m] = ElementAccumulator(
                                    tensor_x.at({thread_n[m], d, h, w, C}));
                        } else {
                            element_A[m] = ElementAccumulator();
                        }
                    }

                    // Load from filters tensor (read at the unmirrored
                    // T, R, S position)
                    CUTLASS_PRAGMA_UNROLL
                    for (int n = 0; n < kThreadN; ++n) {
                        int thread_k = k_start + n;

                        if (thread_k < problem_size.K) {
                            element_B[n] = ElementAccumulator(
                                    tensor_w.at({thread_k, T, R, S, C}));
                        } else {
                            element_B[n] = ElementAccumulator();
                        }
                    }

                    // Accumulate matrix product
                    CUTLASS_PRAGMA_UNROLL
                    for (int m = 0; m < kThreadM; ++m) {
                        CUTLASS_PRAGMA_UNROLL
                        for (int n = 0; n < kThreadN; ++n) {
                            accum[m][n] = inner_product_op(
                                    element_A[m], element_B[n], accum[m][n]);
                        }
                    }

                }  // for (C)
            }      // for (S)
        }          // for (R)
    }              // for (T)

    // Write out the results: only rows and channels inside the problem
    // extents are stored
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < kThreadM; ++m) {
        if (thread_n[m] < problem_size.N && thread_z[m] < problem_size.Z &&
            thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) {
            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < kThreadN; ++n) {
                int thread_k = k_start + n;
                if (thread_k < problem_size.K) {
                    // The source tensor is only touched when beta would
                    // actually scale it
                    ElementCompute c_ref = ElementCompute();
                    if (beta != ElementCompute()) {
                        c_ref = ElementCompute(tensor_y_in.at(
                                {thread_n[m], thread_z[m], thread_p[m],
                                 thread_q[m], thread_k}));
                    }

                    tensor_y_out.at({thread_n[m], thread_z[m], thread_p[m],
                                     thread_q[m], thread_k}) =
                            convert_op(alpha * ElementCompute(accum[m][n]) +
                                       beta * c_ref);
                }
            }  // for (n)
        }
    }  // for (m)
}

///////////////////////////////////////////////////////////////////////////////////////////////////

// Conv2d dgrad reference kernel: dx_out = alpha * dgrad(dy, w) + beta * dx_in
//
// Work split, as implied by the index arithmetic below:
//   - blockIdx.x / threadIdx.x cover the flattened (n, h, w) input positions
//   - blockIdx.y / threadIdx.y cover the c input channels
//   - each thread accumulates a kThreadM x kThreadN tile of gradients
// NOTE(review): block offsets are scaled by kCtaShapeM / kCtaShapeN, so the
// launch presumably uses blockDim = (kCtaShapeM, kCtaShapeN); the launch site
// is not part of this kernel - confirm against the caller.
//
// Tensors are addressed by logical coordinate: dy at (n, p, q, k), w at
// (k, r, s, c), dx_in / dx_out at (n, h, w, c). dx_in is read only when
// beta != ElementCompute().
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator = ElementCompute,
          typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>,
          int kThreadM = 2,  // shape of a thread's tile in the GEMM M dimension
          int kThreadN = 4,  // shape of a thread's tile in the GEMM N dimension
          int kCtaShapeM = 16,  // shape of a threadblock in units of threads
          int kCtaShapeN = 8    // shape of a threadblock in units of threads
          >
__global__ void Conv2dDgrad(conv::Conv2dProblemSize problem_size,
                            TensorRef<ElementA, LayoutA> tensor_dy,
                            TensorRef<ElementB, LayoutB> tensor_w,
                            TensorRef<ElementC, LayoutC> tensor_dx_in,
                            TensorRef<ElementC, LayoutC> tensor_dx_out,
                            ElementCompute alpha, ElementCompute beta) {
    ConvertOp to_output;
    InnerProductOp mac;

    // First flattened (n, h, w) index and first c index owned by this thread
    int64_t const row_base = int64_t(blockIdx.x) * kCtaShapeM * kThreadM +
                             threadIdx.x * kThreadM;
    int const col_base =
            blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

    // Number of input positions per image
    int64_t const plane = int64_t(problem_size.H) * problem_size.W;

    int in_n[kThreadM];
    int in_h[kThreadM];
    int in_w[kThreadM];

    ElementAccumulator frag_dy[kThreadM];
    ElementAccumulator frag_w[kThreadN];
    ElementAccumulator acc[kThreadM][kThreadN];

    // Decompose each tile row into (n, h, w) and zero its accumulators
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kThreadM; ++i) {
        int64_t const row = row_base + i;
        int64_t const within_image = row % plane;

        in_n[i] = int(row / plane);
        in_h[i] = int(within_image / problem_size.W);
        in_w[i] = int(within_image % problem_size.W);

        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < kThreadN; ++j) {
            acc[i][j] = ElementAccumulator();
        }
    }

    // In convolution mode the filter position used to locate the output
    // gradient is mirrored; the filter tensor itself is still read at (r, s).
    bool const mirror_filter =
            problem_size.mode == cutlass::conv::Mode::kConvolution;

    // Reduce over filter rows, filter columns and output channels
    for (int r = 0; r < problem_size.R; ++r) {
        int const tap_r = mirror_filter ? problem_size.R - 1 - r : r;

        for (int s = 0; s < problem_size.S; ++s) {
            int const tap_s = mirror_filter ? problem_size.S - 1 - s : s;

            for (int k = 0; k < problem_size.K; ++k) {
                // Gather output gradients
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < kThreadM; ++i) {
                    int const p_scaled = in_h[i] + problem_size.pad_h -
                                         tap_r * problem_size.dilation_h;
                    int const q_scaled = in_w[i] + problem_size.pad_w -
                                         tap_s * problem_size.dilation_w;

                    frag_dy[i] = ElementAccumulator();

                    // Only non-negative offsets that are exact multiples of
                    // the stride correspond to an output position
                    bool const on_stride =
                            p_scaled >= 0 &&
                            (p_scaled % problem_size.stride_h) == 0 &&
                            q_scaled >= 0 &&
                            (q_scaled % problem_size.stride_w) == 0;
                    if (!on_stride) {
                        continue;
                    }

                    int const p = p_scaled / problem_size.stride_h;
                    int const q = q_scaled / problem_size.stride_w;

                    if (in_n[i] < problem_size.N && p < problem_size.P &&
                        q < problem_size.Q) {
                        frag_dy[i] = ElementAccumulator(
                                tensor_dy.at({in_n[i], p, q, k}));
                    }
                }

                // Gather filter elements for the input channels of this tile
                CUTLASS_PRAGMA_UNROLL
                for (int j = 0; j < kThreadN; ++j) {
                    int const c = col_base + j;

                    frag_w[j] = (c < problem_size.C)
                                        ? ElementAccumulator(
                                                  tensor_w.at({k, r, s, c}))
                                        : ElementAccumulator();
                }

                // Rank-1 update of the thread tile
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < kThreadM; ++i) {
                    CUTLASS_PRAGMA_UNROLL
                    for (int j = 0; j < kThreadN; ++j) {
                        acc[i][j] = mac(frag_dy[i], frag_w[j], acc[i][j]);
                    }
                }
            }
        }
    }

    // Epilogue: scale, optionally blend with the source tensor, and store
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kThreadM; ++i) {
        bool const row_in_range = in_n[i] < problem_size.N &&
                                  in_h[i] < problem_size.H &&
                                  in_w[i] < problem_size.W;
        if (!row_in_range) {
            continue;
        }

        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < kThreadN; ++j) {
            int const c = col_base + j;
            if (c >= problem_size.C) {
                continue;
            }

            ElementCompute source = ElementCompute();
            if (beta != ElementCompute()) {
                source = ElementCompute(
                        tensor_dx_in.at({in_n[i], in_h[i], in_w[i], c}));
            }

            tensor_dx_out.at({in_n[i], in_h[i], in_w[i], c}) = to_output(
                    alpha * ElementCompute(acc[i][j]) + beta * source);
        }
    }
}

// Conv3d dgrad reference kernel: dx_out = alpha * dgrad(dy, w) + beta * dx_in
//
// Work split, as implied by the index arithmetic below:
//   - blockIdx.x / threadIdx.x cover the flattened (n, d, h, w) input
//     positions
//   - blockIdx.y / threadIdx.y cover the c input channels
//   - each thread accumulates a kThreadM x kThreadN tile of gradients
// NOTE(review): block offsets are scaled by kCtaShapeM / kCtaShapeN, so the
// launch presumably uses blockDim = (kCtaShapeM, kCtaShapeN); the launch site
// is not part of this kernel - confirm against the caller.
//
// Tensors are addressed by logical coordinate: dy at (n, z, p, q, k), w at
// (k, t, r, s, c), dx_in / dx_out at (n, d, h, w, c). dx_in is read only when
// beta != ElementCompute().
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator = ElementCompute,
          typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>,
          int kThreadM = 2,  // shape of a thread's tile in the GEMM M dimension
          int kThreadN = 4,  // shape of a thread's tile in the GEMM N dimension
          int kCtaShapeM = 16,  // shape of a threadblock in units of threads
          int kCtaShapeN = 8    // shape of a threadblock in units of threads
          >
__global__ void Conv3dDgrad(conv::Conv3dProblemSize problem_size,
                            TensorRef<ElementA, LayoutA> tensor_dy,
                            TensorRef<ElementB, LayoutB> tensor_w,
                            TensorRef<ElementC, LayoutC> tensor_dx_in,
                            TensorRef<ElementC, LayoutC> tensor_dx_out,
                            ElementCompute alpha, ElementCompute beta) {
    ConvertOp to_output;
    InnerProductOp mac;

    // First flattened (n, d, h, w) index and first c index owned by this
    // thread
    int64_t const row_base = int64_t(blockIdx.x) * kCtaShapeM * kThreadM +
                             threadIdx.x * kThreadM;
    int const col_base =
            blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN;

    // Input positions per depth slice and per image
    int64_t const plane = int64_t(problem_size.H) * problem_size.W;
    int64_t const volume = plane * problem_size.D;

    int in_n[kThreadM];
    int in_d[kThreadM];
    int in_h[kThreadM];
    int in_w[kThreadM];

    ElementAccumulator frag_dy[kThreadM];
    ElementAccumulator frag_w[kThreadN];
    ElementAccumulator acc[kThreadM][kThreadN];

    // Decompose each tile row into (n, d, h, w) and zero its accumulators
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kThreadM; ++i) {
        int64_t const row = row_base + i;
        int64_t const within_image = row % volume;
        int64_t const within_slice = within_image % plane;

        in_n[i] = int(row / volume);
        in_d[i] = int(within_image / plane);
        in_h[i] = int(within_slice / problem_size.W);
        in_w[i] = int(within_slice % problem_size.W);

        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < kThreadN; ++j) {
            acc[i][j] = ElementAccumulator();
        }
    }

    // In convolution mode the filter position used to locate the output
    // gradient is mirrored; the filter tensor itself is still read at
    // (t, r, s).
    bool const mirror_filter =
            problem_size.mode == cutlass::conv::Mode::kConvolution;

    // Reduce over filter depth, rows, columns and output channels
    for (int t = 0; t < problem_size.T; ++t) {
        int const tap_t = mirror_filter ? problem_size.T - 1 - t : t;

        for (int r = 0; r < problem_size.R; ++r) {
            int const tap_r = mirror_filter ? problem_size.R - 1 - r : r;

            for (int s = 0; s < problem_size.S; ++s) {
                int const tap_s = mirror_filter ? problem_size.S - 1 - s : s;

                for (int k = 0; k < problem_size.K; ++k) {
                    // Gather output gradients
                    CUTLASS_PRAGMA_UNROLL
                    for (int i = 0; i < kThreadM; ++i) {
                        int const z_scaled = in_d[i] + problem_size.pad_d -
                                             tap_t * problem_size.dilation_d;
                        int const p_scaled = in_h[i] + problem_size.pad_h -
                                             tap_r * problem_size.dilation_h;
                        int const q_scaled = in_w[i] + problem_size.pad_w -
                                             tap_s * problem_size.dilation_w;

                        frag_dy[i] = ElementAccumulator();

                        // Only non-negative offsets that are exact multiples
                        // of the stride correspond to an output position
                        bool const on_stride =
                                z_scaled >= 0 &&
                                (z_scaled % problem_size.stride_d) == 0 &&
                                p_scaled >= 0 &&
                                (p_scaled % problem_size.stride_h) == 0 &&
                                q_scaled >= 0 &&
                                (q_scaled % problem_size.stride_w) == 0;
                        if (!on_stride) {
                            continue;
                        }

                        int const z = z_scaled / problem_size.stride_d;
                        int const p = p_scaled / problem_size.stride_h;
                        int const q = q_scaled / problem_size.stride_w;

                        if (in_n[i] < problem_size.N && z < problem_size.Z &&
                            p < problem_size.P && q < problem_size.Q) {
                            frag_dy[i] = ElementAccumulator(
                                    tensor_dy.at({in_n[i], z, p, q, k}));
                        }
                    }

                    // Gather filter elements for the input channels of this
                    // tile
                    CUTLASS_PRAGMA_UNROLL
                    for (int j = 0; j < kThreadN; ++j) {
                        int const c = col_base + j;

                        frag_w[j] =
                                (c < problem_size.C)
                                        ? ElementAccumulator(tensor_w.at(
                                                  {k, t, r, s, c}))
                                        : ElementAccumulator();
                    }

                    // Rank-1 update of the thread tile
                    CUTLASS_PRAGMA_UNROLL
                    for (int i = 0; i < kThreadM; ++i) {
                        CUTLASS_PRAGMA_UNROLL
                        for (int j = 0; j < kThreadN; ++j) {
                            acc[i][j] = mac(frag_dy[i], frag_w[j], acc[i][j]);
                        }
                    }

                }  // for (k)
            }      // for (s)
        }          // for (r)
    }              // for (t)

    // Epilogue: scale, optionally blend with the source tensor, and store
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kThreadM; ++i) {
        bool const row_in_range =
                in_n[i] < problem_size.N && in_d[i] < problem_size.D &&
                in_h[i] < problem_size.H && in_w[i] < problem_size.W;
        if (!row_in_range) {
            continue;
        }

        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < kThreadN; ++j) {
            int const c = col_base + j;
            if (c >= problem_size.C) {
                continue;
            }

            ElementCompute source = ElementCompute();
            if (beta != ElementCompute()) {
                source = ElementCompute(tensor_dx_in.at(
                        {in_n[i], in_d[i], in_h[i], in_w[i], c}));
            }

            tensor_dx_out.at({in_n[i], in_d[i], in_h[i], in_w[i], c}) =
                    to_output(alpha * ElementCompute(acc[i][j]) +
                              beta * source);
        }
    }
}

///////////////////////////////////////////////////////////////////////////////////////////////////

// Conv2d wgrad reference kernel: dw_out = alpha * wgrad(dy, x) + beta * dw_in
//
// Work split, as implied by the index arithmetic below:
//   - blockIdx.x / threadIdx.x cover the k output channels
//   - blockIdx.y / threadIdx.y cover the flattened (r, s, c) filter positions
//   - each thread accumulates a kThreadM x kThreadN tile of filter gradients
// NOTE(review): block offsets are scaled by kCtaShapeM / kCtaShapeN, so the
// launch presumably uses blockDim = (kCtaShapeM, kCtaShapeN); the launch site
// is not part of this kernel - confirm against the caller.
//
// Tensors are addressed by logical coordinate: dy at (n, p, q, k), x at
// (n, h, w, c), dw_in / dw_out at (k, r, s, c). dw_in is read only when
// beta != ElementCompute().
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator = ElementCompute,
          typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>,
          int kThreadM = 2,  // shape of a thread's tile in the GEMM M dimension
          int kThreadN = 4,  // shape of a thread's tile in the GEMM N dimension
          int kCtaShapeM = 8,  // shape of a threadblock in units of threads
          int kCtaShapeN = 16  // shape of a threadblock in units of threads
          >
__global__ void Conv2dWgrad(conv::Conv2dProblemSize problem_size,
                            TensorRef<ElementA, LayoutA> tensor_dy,
                            TensorRef<ElementB, LayoutB> tensor_x,
                            TensorRef<ElementC, LayoutC> tensor_dw_in,
                            TensorRef<ElementC, LayoutC> tensor_dw_out,
                            ElementCompute alpha, ElementCompute beta) {
    ConvertOp to_output;
    InnerProductOp mac;

    // First k index and first flattened (r, s, c) index owned by this thread
    int const row_base =
            blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
    int64_t const col_base = int64_t(blockIdx.y) * kCtaShapeN * kThreadN +
                             threadIdx.y * kThreadN;

    // Number of (s, c) pairs per filter row
    int64_t const row_span = int64_t(problem_size.S) * problem_size.C;

    bool const mirror_filter =
            problem_size.mode == cutlass::conv::Mode::kConvolution;

    int flt_r[kThreadN];
    int flt_s[kThreadN];
    int flt_c[kThreadN];

    // Filter positions used to locate the activation (mirrored in
    // convolution mode); the stored gradient still goes to flt_r / flt_s
    int tap_r[kThreadN];
    int tap_s[kThreadN];

    ElementAccumulator frag_dy[kThreadM];
    ElementAccumulator frag_x[kThreadN];
    ElementAccumulator acc[kThreadM][kThreadN];

    // Decompose each tile column into (r, s, c)
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < kThreadN; ++j) {
        int64_t const col = col_base + j;
        int64_t const within_row = col % row_span;

        flt_r[j] = int(col / row_span);
        flt_s[j] = int(within_row / problem_size.C);
        flt_c[j] = int(within_row % problem_size.C);

        tap_r[j] = mirror_filter ? problem_size.R - 1 - flt_r[j] : flt_r[j];
        tap_s[j] = mirror_filter ? problem_size.S - 1 - flt_s[j] : flt_s[j];
    }

    // Zero the accumulators
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kThreadM; ++i) {
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < kThreadN; ++j) {
            acc[i][j] = ElementAccumulator();
        }
    }

    // Reduce over batch and output spatial positions
    for (int batch = 0; batch < problem_size.N; ++batch) {
        for (int p = 0; p < problem_size.P; ++p) {
            for (int q = 0; q < problem_size.Q; ++q) {
                // Gather output gradients for the channels of this tile
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < kThreadM; ++i) {
                    int const k = row_base + i;

                    frag_dy[i] = (k < problem_size.K)
                                         ? ElementAccumulator(tensor_dy.at(
                                                   {batch, p, q, k}))
                                         : ElementAccumulator();
                }

                // Gather activations; anything outside the image reads as a
                // value-initialized element
                CUTLASS_PRAGMA_UNROLL
                for (int j = 0; j < kThreadN; ++j) {
                    int const h = p * problem_size.stride_h -
                                  problem_size.pad_h +
                                  tap_r[j] * problem_size.dilation_h;
                    int const w = q * problem_size.stride_w -
                                  problem_size.pad_w +
                                  tap_s[j] * problem_size.dilation_w;

                    bool const inside = h >= 0 && h < problem_size.H &&
                                        w >= 0 && w < problem_size.W &&
                                        flt_c[j] < problem_size.C;

                    frag_x[j] = inside ? ElementAccumulator(tensor_x.at(
                                                 {batch, h, w, flt_c[j]}))
                                       : ElementAccumulator();
                }

                // Rank-1 update of the thread tile
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < kThreadM; ++i) {
                    CUTLASS_PRAGMA_UNROLL
                    for (int j = 0; j < kThreadN; ++j) {
                        acc[i][j] = mac(frag_dy[i], frag_x[j], acc[i][j]);
                    }
                }
            }
        }
    }

    // Epilogue: scale, optionally blend with the source tensor, and store
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kThreadM; ++i) {
        int const k = row_base + i;
        if (k >= problem_size.K) {
            continue;
        }

        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < kThreadN; ++j) {
            bool const col_in_range = flt_r[j] < problem_size.R &&
                                      flt_s[j] < problem_size.S &&
                                      flt_c[j] < problem_size.C;
            if (!col_in_range) {
                continue;
            }

            ElementCompute source = ElementCompute();
            if (beta != ElementCompute()) {
                source = ElementCompute(
                        tensor_dw_in.at({k, flt_r[j], flt_s[j], flt_c[j]}));
            }

            tensor_dw_out.at({k, flt_r[j], flt_s[j], flt_c[j]}) = to_output(
                    alpha * ElementCompute(acc[i][j]) + beta * source);
        }
    }
}

/// Conv3d wgrad reference kernel - dw = wgrad(dy, x)
///
/// Each thread accumulates a kThreadM x kThreadN tile of the filter gradient.
/// The M index of the tile selects an output channel k; the N index is a
/// linearized (t, r, s, c) filter coordinate. For every (n, z, p, q) output
/// position the thread multiplies dy(n, z, p, q, k) by the activation
/// x(n, d, h, w, c) that the filter tap touches and accumulates the product.
///
/// Launch layout assumed by the indexing below: blockIdx.x / threadIdx.x walk
/// the K dimension in units of kCtaShapeM * kThreadM, and blockIdx.y /
/// threadIdx.y walk the linearized T*R*S*C dimension in units of
/// kCtaShapeN * kThreadN.
///
/// \param problem_size  convolution extents, padding, stride, dilation, mode
/// \param tensor_dy     output gradient, indexed as {n, z, p, q, k}
/// \param tensor_x      activations, indexed as {n, d, h, w, c}
/// \param tensor_dw_in  source filter gradient, indexed as {k, t, r, s, c};
///                      only read when beta is non-zero
/// \param tensor_dw_out destination filter gradient, indexed as {k, t, r, s, c}
/// \param alpha         scale applied to the accumulated product
/// \param beta          scale applied to the value read from tensor_dw_in
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator = ElementCompute,
          typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>,
          int kThreadM = 2,  // shape of a thread's tile in the GEMM M dimension
          int kThreadN = 4,  // shape of a thread's tile in the GEMM N dimension
          int kCtaShapeM = 8,  // shape of a threadblock in units of threads
          int kCtaShapeN = 16  // shape of a threadblock in units of threads
          >
__global__ void Conv3dWgrad(conv::Conv3dProblemSize problem_size,
                            TensorRef<ElementA, LayoutA> tensor_dy,
                            TensorRef<ElementB, LayoutB> tensor_x,
                            TensorRef<ElementC, LayoutC> tensor_dw_in,
                            TensorRef<ElementC, LayoutC> tensor_dw_out,
                            ElementCompute alpha, ElementCompute beta) {
    ConvertOp convert_op;
    InnerProductOp inner_product_op;

    ElementAccumulator element_A[kThreadM];
    ElementAccumulator element_B[kThreadN];
    ElementAccumulator accum[kThreadM][kThreadN];

    // First output channel and first linearized (t, r, s, c) index owned by
    // this thread. The TRSC index is kept in 64 bits so the product of the
    // filter extents cannot overflow.
    int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM;
    int64_t trsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN +
                         threadIdx.y * kThreadN;

    int thread_t[kThreadN];
    int thread_r[kThreadN];
    int thread_s[kThreadN];
    int thread_c[kThreadN];

    // Compute T, R, S, C coordinates for each column of a thread's tile
    int64_t SC = int64_t(problem_size.S) * problem_size.C;
    int64_t RSC = SC * problem_size.R;

    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < kThreadN; ++n) {
        int64_t trsc = trsc_start + n;

        // Indices past the end of the filter yield thread_t >= T; they are
        // rejected by the bounds check before the final store.
        thread_t[n] = int(trsc / RSC);

        int64_t residual = trsc % RSC;
        thread_r[n] = int(residual / SC);

        residual = residual % SC;
        thread_s[n] = int(residual / problem_size.C);
        thread_c[n] = int(residual % problem_size.C);
    }

    // Clear accumulators
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < kThreadM; ++m) {
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < kThreadN; ++n) {
            accum[m][n] = ElementAccumulator();
        }
    }

    // Compute convolution
    for (int N = 0; N < problem_size.N; ++N) {
        for (int Z = 0; Z < problem_size.Z; ++Z) {
            for (int P = 0; P < problem_size.P; ++P) {
                for (int Q = 0; Q < problem_size.Q; ++Q) {
                    // Load from output-gradient tensor (dy); channels beyond K
                    // contribute zero.
                    CUTLASS_PRAGMA_UNROLL
                    for (int m = 0; m < kThreadM; ++m) {
                        int thread_k = k_start + m;

                        element_A[m] = ElementAccumulator();

                        if (thread_k < problem_size.K) {
                            element_A[m] = ElementAccumulator(
                                    tensor_dy.at({N, Z, P, Q, thread_k}));
                        }
                    }

                    // Load from activations tensor (x)
                    CUTLASS_PRAGMA_UNROLL
                    for (int n = 0; n < kThreadN; ++n) {
                        int filter_t = thread_t[n];
                        int filter_r = thread_r[n];
                        int filter_s = thread_s[n];

                        // Convolution mode (as opposed to cross-correlation)
                        // visits the filter taps in reversed order.
                        if (problem_size.mode ==
                            cutlass::conv::Mode::kConvolution) {
                            filter_t = problem_size.T - 1 - filter_t;
                            filter_r = problem_size.R - 1 - filter_r;
                            filter_s = problem_size.S - 1 - filter_s;
                        }

                        // Map the output position and filter tap back to the
                        // activation coordinate. The depth axis must use the
                        // depth padding (pad_d), not the width padding.
                        int d = Z * problem_size.stride_d - problem_size.pad_d +
                                filter_t * problem_size.dilation_d;
                        int h = P * problem_size.stride_h - problem_size.pad_h +
                                filter_r * problem_size.dilation_h;
                        int w = Q * problem_size.stride_w - problem_size.pad_w +
                                filter_s * problem_size.dilation_w;

                        element_B[n] = ElementAccumulator();

                        // Coordinates falling in the padding region (or past
                        // the channel extent) contribute zero.
                        if (d >= 0 && d < problem_size.D && h >= 0 &&
                            h < problem_size.H && w >= 0 &&
                            w < problem_size.W &&
                            thread_c[n] < problem_size.C) {
                            element_B[n] = ElementAccumulator(
                                    tensor_x.at({N, d, h, w, thread_c[n]}));
                        }
                    }

                    // Accumulate matrix product
                    CUTLASS_PRAGMA_UNROLL
                    for (int m = 0; m < kThreadM; ++m) {
                        CUTLASS_PRAGMA_UNROLL
                        for (int n = 0; n < kThreadN; ++n) {
                            accum[m][n] = inner_product_op(
                                    element_A[m], element_B[n], accum[m][n]);
                        }
                    }

                }  // for (Q)
            }      // for (P)
        }          // for (Z)
    }              // for (N)

    // Write out the results
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < kThreadM; ++m) {
        int thread_k = k_start + m;

        if (thread_k < problem_size.K) {
            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < kThreadN; ++n) {
                if (thread_t[n] < problem_size.T &&
                    thread_r[n] < problem_size.R &&
                    thread_s[n] < problem_size.S &&
                    thread_c[n] < problem_size.C) {
                    ElementCompute c_ref = ElementCompute();

                    // Skip reading the source tensor entirely when beta is
                    // zero.
                    if (beta != ElementCompute()) {
                        c_ref = ElementCompute(tensor_dw_in.at(
                                {thread_k, thread_t[n], thread_r[n],
                                 thread_s[n], thread_c[n]}));
                    }

                    tensor_dw_out.at({thread_k, thread_t[n], thread_r[n],
                                      thread_s[n], thread_c[n]}) =
                            convert_op(alpha * ElementCompute(accum[m][n]) +
                                       beta * c_ref);
                }
            }
        }
    }
}

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace kernel

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Conv2d Fprop dispatcher - y = fprop(x, w)
///
/// Launches kernel::Conv2dFprop on `stream`. The grid's x dimension covers the
/// N*P*Q extent and its y dimension covers the K extent. The launch is not
/// followed by a synchronization.
///
/// \return Status::kErrorInternal if cudaPeekAtLastError() reports an error
///         after the launch, Status::kSuccess otherwise.
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator = ElementCompute,
          typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>>
Status Conv2dFprop(conv::Conv2dProblemSize problem_size,
                   TensorRef<ElementA, LayoutA> tensor_x,
                   TensorRef<ElementB, LayoutB> tensor_w,
                   TensorRef<ElementC, LayoutC> tensor_y_in,
                   TensorRef<ElementC, LayoutC> tensor_y_out,
                   ElementCompute alpha, ElementCompute beta,
                   cudaStream_t stream = nullptr) {
    // Blocking factors: each thread owns a kThreadM x kThreadN tile, and each
    // threadblock is kCtaShapeM x kCtaShapeN threads.
    int const kThreadM = 4;
    int const kThreadN = 4;
    int const kCtaShapeM = 16;
    int const kCtaShapeN = 8;

    // Extent covered by one threadblock along each GEMM dimension.
    int const kTileM = kCtaShapeM * kThreadM;
    int const kTileN = kCtaShapeN * kThreadN;

    // GEMM M extent, evaluated in 64 bits before the ceil-division.
    int64_t const gemm_m =
            int64_t(problem_size.N) * problem_size.P * problem_size.Q;
    int64_t const grid_m = (gemm_m + kTileM - 1) / kTileM;
    int const grid_n = (problem_size.K + kTileN - 1) / kTileN;

    dim3 threads(kCtaShapeM, kCtaShapeN);
    dim3 blocks(uint32_t(grid_m), grid_n);

    kernel::Conv2dFprop<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                        ElementCompute, ElementAccumulator, ConvertOp,
                        InnerProductOp, kThreadM, kThreadN, kCtaShapeM,
                        kCtaShapeN><<<blocks, threads, 0, stream>>>(
            problem_size, tensor_x, tensor_w, tensor_y_in, tensor_y_out, alpha,
            beta);

    // Peek at (without clearing) the runtime error state to detect a failed
    // launch.
    return (cudaPeekAtLastError() == cudaSuccess) ? Status::kSuccess
                                                  : Status::kErrorInternal;
}

/// Conv3d Fprop dispatcher - y = fprop(x, w)
///
/// Launches kernel::Conv3dFprop on `stream`. The grid's x dimension covers the
/// N*Z*P*Q extent and its y dimension covers the K extent. The launch is not
/// followed by a synchronization.
///
/// \return Status::kErrorInternal if cudaPeekAtLastError() reports an error
///         after the launch, Status::kSuccess otherwise.
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator = ElementCompute,
          typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>>
Status Conv3dFprop(conv::Conv3dProblemSize problem_size,
                   TensorRef<ElementA, LayoutA> tensor_x,
                   TensorRef<ElementB, LayoutB> tensor_w,
                   TensorRef<ElementC, LayoutC> tensor_y_in,
                   TensorRef<ElementC, LayoutC> tensor_y_out,
                   ElementCompute alpha, ElementCompute beta,
                   cudaStream_t stream = nullptr) {
    // Blocking factors: each thread owns a kThreadM x kThreadN tile, and each
    // threadblock is kCtaShapeM x kCtaShapeN threads.
    int const kThreadM = 4;
    int const kThreadN = 4;
    int const kCtaShapeM = 16;
    int const kCtaShapeN = 8;

    // Extent covered by one threadblock along each GEMM dimension.
    int const kTileM = kCtaShapeM * kThreadM;
    int const kTileN = kCtaShapeN * kThreadN;

    // GEMM M extent, evaluated in 64 bits before the ceil-division.
    int64_t const gemm_m = int64_t(problem_size.N) * problem_size.Z *
                           problem_size.P * problem_size.Q;
    int64_t const grid_m = (gemm_m + kTileM - 1) / kTileM;
    int const grid_n = (problem_size.K + kTileN - 1) / kTileN;

    dim3 threads(kCtaShapeM, kCtaShapeN);
    dim3 blocks(uint32_t(grid_m), grid_n);

    kernel::Conv3dFprop<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                        ElementCompute, ElementAccumulator, ConvertOp,
                        InnerProductOp, kThreadM, kThreadN, kCtaShapeM,
                        kCtaShapeN><<<blocks, threads, 0, stream>>>(
            problem_size, tensor_x, tensor_w, tensor_y_in, tensor_y_out, alpha,
            beta);

    // Peek at (without clearing) the runtime error state to detect a failed
    // launch.
    return (cudaPeekAtLastError() == cudaSuccess) ? Status::kSuccess
                                                  : Status::kErrorInternal;
}

/// Conv2d Dgrad dispatcher - dx = dgrad(dy, w)
///
/// Launches kernel::Conv2dDgrad on `stream`. The grid's x dimension covers the
/// N*H*W extent and its y dimension covers the C extent. The launch is not
/// followed by a synchronization.
///
/// \return Status::kErrorInternal if cudaPeekAtLastError() reports an error
///         after the launch, Status::kSuccess otherwise.
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator = ElementCompute,
          typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>>
Status Conv2dDgrad(conv::Conv2dProblemSize problem_size,
                   TensorRef<ElementA, LayoutA> tensor_dy,
                   TensorRef<ElementB, LayoutB> tensor_w,
                   TensorRef<ElementC, LayoutC> tensor_dx_in,
                   TensorRef<ElementC, LayoutC> tensor_dx_out,
                   ElementCompute alpha, ElementCompute beta,
                   cudaStream_t stream = nullptr) {
    // Blocking factors: each thread owns a kThreadM x kThreadN tile, and each
    // threadblock is kCtaShapeM x kCtaShapeN threads.
    int const kThreadM = 2;
    int const kThreadN = 4;
    int const kCtaShapeM = 16;
    int const kCtaShapeN = 8;

    // Extent covered by one threadblock along each GEMM dimension.
    int const kTileM = kCtaShapeM * kThreadM;
    int const kTileN = kCtaShapeN * kThreadN;

    // GEMM M extent, evaluated in 64 bits before the ceil-division.
    int64_t const gemm_m =
            int64_t(problem_size.N) * problem_size.H * problem_size.W;
    int64_t const grid_m = (gemm_m + kTileM - 1) / kTileM;
    int const grid_n = (problem_size.C + kTileN - 1) / kTileN;

    dim3 threads(kCtaShapeM, kCtaShapeN);
    dim3 blocks(uint32_t(grid_m), grid_n);

    kernel::Conv2dDgrad<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                        ElementCompute, ElementAccumulator, ConvertOp,
                        InnerProductOp, kThreadM, kThreadN, kCtaShapeM,
                        kCtaShapeN><<<blocks, threads, 0, stream>>>(
            problem_size, tensor_dy, tensor_w, tensor_dx_in, tensor_dx_out,
            alpha, beta);

    // Peek at (without clearing) the runtime error state to detect a failed
    // launch.
    return (cudaPeekAtLastError() == cudaSuccess) ? Status::kSuccess
                                                  : Status::kErrorInternal;
}

/// Conv3d Dgrad dispatcher - dx = dgrad(dy, w)
///
/// Launches kernel::Conv3dDgrad on `stream`. The grid's x dimension covers the
/// N*D*H*W extent and its y dimension covers the C extent. The launch is not
/// followed by a synchronization.
///
/// \return Status::kErrorInternal if cudaPeekAtLastError() reports an error
///         after the launch, Status::kSuccess otherwise.
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator = ElementCompute,
          typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>>
Status Conv3dDgrad(conv::Conv3dProblemSize problem_size,
                   TensorRef<ElementA, LayoutA> tensor_dy,
                   TensorRef<ElementB, LayoutB> tensor_w,
                   TensorRef<ElementC, LayoutC> tensor_dx_in,
                   TensorRef<ElementC, LayoutC> tensor_dx_out,
                   ElementCompute alpha, ElementCompute beta,
                   cudaStream_t stream = nullptr) {
    // Blocking factors: each thread owns a kThreadM x kThreadN tile, and each
    // threadblock is kCtaShapeM x kCtaShapeN threads.
    int const kThreadM = 2;
    int const kThreadN = 4;
    int const kCtaShapeM = 16;
    int const kCtaShapeN = 8;

    // Extent covered by one threadblock along each GEMM dimension.
    int const kTileM = kCtaShapeM * kThreadM;
    int const kTileN = kCtaShapeN * kThreadN;

    // GEMM M extent, evaluated in 64 bits before the ceil-division.
    int64_t const gemm_m = int64_t(problem_size.N) * problem_size.D *
                           problem_size.H * problem_size.W;
    int64_t const grid_m = (gemm_m + kTileM - 1) / kTileM;
    int const grid_n = (problem_size.C + kTileN - 1) / kTileN;

    dim3 threads(kCtaShapeM, kCtaShapeN);
    dim3 blocks(uint32_t(grid_m), grid_n);

    kernel::Conv3dDgrad<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                        ElementCompute, ElementAccumulator, ConvertOp,
                        InnerProductOp, kThreadM, kThreadN, kCtaShapeM,
                        kCtaShapeN><<<blocks, threads, 0, stream>>>(
            problem_size, tensor_dy, tensor_w, tensor_dx_in, tensor_dx_out,
            alpha, beta);

    // Peek at (without clearing) the runtime error state to detect a failed
    // launch.
    return (cudaPeekAtLastError() == cudaSuccess) ? Status::kSuccess
                                                  : Status::kErrorInternal;
}

/// Conv2d Wgrad dispatcher - dw = wgrad(dy, x)
///
/// Launches kernel::Conv2dWgrad on `stream`. The grid's x dimension covers the
/// K extent and its y dimension covers the R*S*C extent. The launch is not
/// followed by a synchronization.
///
/// \return Status::kErrorInternal if cudaPeekAtLastError() reports an error
///         after the launch, Status::kSuccess otherwise.
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator = ElementCompute,
          typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>>
Status Conv2dWgrad(conv::Conv2dProblemSize problem_size,
                   TensorRef<ElementA, LayoutA> tensor_dy,
                   TensorRef<ElementB, LayoutB> tensor_x,
                   TensorRef<ElementC, LayoutC> tensor_dw_in,
                   TensorRef<ElementC, LayoutC> tensor_dw_out,
                   ElementCompute alpha, ElementCompute beta,
                   cudaStream_t stream = nullptr) {
    // Blocking factors: each thread owns a kThreadM x kThreadN tile, and each
    // threadblock is kCtaShapeM x kCtaShapeN threads.
    int const kThreadM = 2;
    int const kThreadN = 4;
    int const kCtaShapeM = 8;
    int const kCtaShapeN = 16;

    // Extent covered by one threadblock along each GEMM dimension.
    int const kTileM = kCtaShapeM * kThreadM;
    int const kTileN = kCtaShapeN * kThreadN;

    // GEMM N extent, evaluated in 64 bits before the ceil-division.
    int64_t const gemm_n =
            int64_t(problem_size.R) * problem_size.S * problem_size.C;
    int64_t const grid_n = (gemm_n + kTileN - 1) / kTileN;
    int const grid_m = (problem_size.K + kTileM - 1) / kTileM;

    dim3 threads(kCtaShapeM, kCtaShapeN);
    dim3 blocks(grid_m, uint32_t(grid_n));

    kernel::Conv2dWgrad<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                        ElementCompute, ElementAccumulator, ConvertOp,
                        InnerProductOp, kThreadM, kThreadN, kCtaShapeM,
                        kCtaShapeN><<<blocks, threads, 0, stream>>>(
            problem_size, tensor_dy, tensor_x, tensor_dw_in, tensor_dw_out,
            alpha, beta);

    // Peek at (without clearing) the runtime error state to detect a failed
    // launch.
    return (cudaPeekAtLastError() == cudaSuccess) ? Status::kSuccess
                                                  : Status::kErrorInternal;
}

/// Conv3d Wgrad dispatcher - dw = wgrad(dy, x)
///
/// Launches kernel::Conv3dWgrad on `stream`. The grid's x dimension covers the
/// K extent and its y dimension covers the T*R*S*C extent. The launch is not
/// followed by a synchronization.
///
/// \return Status::kErrorInternal if cudaPeekAtLastError() reports an error
///         after the launch, Status::kSuccess otherwise.
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator = ElementCompute,
          typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>>
Status Conv3dWgrad(conv::Conv3dProblemSize problem_size,
                   TensorRef<ElementA, LayoutA> tensor_dy,
                   TensorRef<ElementB, LayoutB> tensor_x,
                   TensorRef<ElementC, LayoutC> tensor_dw_in,
                   TensorRef<ElementC, LayoutC> tensor_dw_out,
                   ElementCompute alpha, ElementCompute beta,
                   cudaStream_t stream = nullptr) {
    // Blocking factors: each thread owns a kThreadM x kThreadN tile, and each
    // threadblock is kCtaShapeM x kCtaShapeN threads.
    int const kThreadM = 2;
    int const kThreadN = 4;
    int const kCtaShapeM = 8;
    int const kCtaShapeN = 16;

    // Extent covered by one threadblock along each GEMM dimension.
    int const kTileM = kCtaShapeM * kThreadM;
    int const kTileN = kCtaShapeN * kThreadN;

    // GEMM N extent, evaluated in 64 bits before the ceil-division.
    int64_t const gemm_n = int64_t(problem_size.T) * problem_size.R *
                           problem_size.S * problem_size.C;
    int64_t const grid_n = (gemm_n + kTileN - 1) / kTileN;
    int const grid_m = (problem_size.K + kTileM - 1) / kTileM;

    dim3 threads(kCtaShapeM, kCtaShapeN);
    dim3 blocks(grid_m, uint32_t(grid_n));

    kernel::Conv3dWgrad<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                        ElementCompute, ElementAccumulator, ConvertOp,
                        InnerProductOp, kThreadM, kThreadN, kCtaShapeM,
                        kCtaShapeN><<<blocks, threads, 0, stream>>>(
            problem_size, tensor_dy, tensor_x, tensor_dw_in, tensor_dw_out,
            alpha, beta);

    // Peek at (without clearing) the runtime error state to detect a failed
    // launch.
    return (cudaPeekAtLastError() == cudaSuccess) ? Status::kSuccess
                                                  : Status::kErrorInternal;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad.
///
/// Selects the dispatcher matching `convolutional_operator` and forwards all
/// arguments to it unchanged: tensor_A and tensor_B as the two operands,
/// tensor_C as the source ("_in") tensor and tensor_D as the destination
/// ("_out") tensor.
///
/// \return the status of the selected dispatcher, or
///         Status::kErrorNotSupported for any other operator value.
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator = ElementCompute,
          typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>>
Status Conv2d(conv::Operator convolutional_operator,
              conv::Conv2dProblemSize problem_size,
              TensorRef<ElementA, LayoutA> tensor_A,
              TensorRef<ElementB, LayoutB> tensor_B,
              TensorRef<ElementC, LayoutC> tensor_C,
              TensorRef<ElementC, LayoutC> tensor_D, ElementCompute alpha,
              ElementCompute beta, cudaStream_t stream = nullptr) {
    // Forward propagation
    if (convolutional_operator == conv::Operator::kFprop) {
        return Conv2dFprop<ElementA, LayoutA, ElementB, LayoutB, ElementC,
                           LayoutC, ElementCompute, ElementAccumulator,
                           ConvertOp, InnerProductOp>(
                problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha,
                beta, stream);
    }

    // Data gradient
    if (convolutional_operator == conv::Operator::kDgrad) {
        return Conv2dDgrad<ElementA, LayoutA, ElementB, LayoutB, ElementC,
                           LayoutC, ElementCompute, ElementAccumulator,
                           ConvertOp, InnerProductOp>(
                problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha,
                beta, stream);
    }

    // Weight gradient
    if (convolutional_operator == conv::Operator::kWgrad) {
        return Conv2dWgrad<ElementA, LayoutA, ElementB, LayoutB, ElementC,
                           LayoutC, ElementCompute, ElementAccumulator,
                           ConvertOp, InnerProductOp>(
                problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha,
                beta, stream);
    }

    // No dispatcher exists for any other operator value.
    return Status::kErrorNotSupported;
}

/// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad.
///
/// Selects the dispatcher matching `convolutional_operator` and forwards all
/// arguments to it unchanged: tensor_A and tensor_B as the two operands,
/// tensor_C as the source ("_in") tensor and tensor_D as the destination
/// ("_out") tensor.
///
/// \return the status of the selected dispatcher, or
///         Status::kErrorNotSupported for any other operator value.
template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator = ElementCompute,
          typename ConvertOp = NumericConverter<ElementC, ElementCompute>,
          typename InnerProductOp = multiply_add<ElementAccumulator>>
Status Conv3d(conv::Operator convolutional_operator,
              conv::Conv3dProblemSize problem_size,
              TensorRef<ElementA, LayoutA> tensor_A,
              TensorRef<ElementB, LayoutB> tensor_B,
              TensorRef<ElementC, LayoutC> tensor_C,
              TensorRef<ElementC, LayoutC> tensor_D, ElementCompute alpha,
              ElementCompute beta, cudaStream_t stream = nullptr) {
    // Forward propagation
    if (convolutional_operator == conv::Operator::kFprop) {
        return Conv3dFprop<ElementA, LayoutA, ElementB, LayoutB, ElementC,
                           LayoutC, ElementCompute, ElementAccumulator,
                           ConvertOp, InnerProductOp>(
                problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha,
                beta, stream);
    }

    // Data gradient
    if (convolutional_operator == conv::Operator::kDgrad) {
        return Conv3dDgrad<ElementA, LayoutA, ElementB, LayoutB, ElementC,
                           LayoutC, ElementCompute, ElementAccumulator,
                           ConvertOp, InnerProductOp>(
                problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha,
                beta, stream);
    }

    // Weight gradient
    if (convolutional_operator == conv::Operator::kWgrad) {
        return Conv3dWgrad<ElementA, LayoutA, ElementB, LayoutB, ElementC,
                           LayoutC, ElementCompute, ElementAccumulator,
                           ConvertOp, InnerProductOp>(
                problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha,
                beta, stream);
    }

    // No dispatcher exists for any other operator value.
    return Status::kErrorNotSupported;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace device
}  // namespace reference
}  // namespace cutlass

////////////////////////////////////////////////////////////////////////////////////////////////////
